(in-package :progalgs)

(defparameter *b64-dict*
  (coerce (concatenate 'string
                       "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                       "abcdefghijklmnopqrstuvwxyz"
                       "0123456789"
                       "+/=")
          'simple-vector)
  "Base64 alphabet as a simple-vector of characters.
Indices 0-63 are the digits for the 6-bit values; index 64 holds the
padding character #\\=.")

(defun b64-encode (in out)
  "Base64-encode the octets read from the binary stream IN, writing the
character codes of the resulting digits to the binary stream OUT.
Reading continues until READ-BYTE reports end of file.  When the input
length is not a multiple of 3, the last digit is zero-padded and one or two
padding entries (index 64 of *B64-DICT*) are appended."
  (let ((sextet 0)   ; the 6-bit group that is currently being assembled
        (need 6))    ; number of low-order bits of SEXTET still to be filled
    (flet ((copy-bits (octet size from to)
             ;; copy SIZE bits of OCTET, starting at bit FROM,
             ;; into SEXTET starting at bit TO
             (setf (ldb (byte size to) sextet)
                   (ldb (byte size from) octet)))
           (emit (index)
             (write-byte (char-code (svref *b64-dict* index)) out)))
      (loop
        (let ((octet (read-byte in nil)))
          (unless octet
            (return))
          ;; SPARE is the number of low bits of OCTET that do not fit
          ;; into the current group
          (let ((spare (- 8 need)))
            ;; the top NEED bits of OCTET complete the current group
            (copy-bits octet need spare 0)
            (emit sextet)
            ;; the SPARE low bits become the high part of the next group
            (setf need (- 6 spare))
            (copy-bits octet spare 0 need)
            ;; every third octet leaves a whole group behind
            (when (= spare 6)
              (emit sextet)
              (setf need 6)))))
      ;; flush a partially filled group, zeroing the bits that were not
      ;; supplied by the input, and add the padding
      (when (< need 6)
        (setf (ldb (byte need 0) sextet) 0)
        (emit sextet)
        (loop :repeat (ceiling need 2) :do
          (emit 64))))))

(defun b64str (str)
  "Return the Base64 encoding of STR as a string.
Each character of STR is turned into its CHAR-CODE and fed to B64-ENCODE
through in-memory streams.
NOTE(review): B64-ENCODE extracts only 8 bits per input element, so
characters with codes above 255 are presumably not encoded faithfully --
confirm before using this with non-Latin-1 text."
  (let ((in (flex:make-in-memory-input-stream (map 'vector #'char-code str)))
        (out (flex:make-in-memory-output-stream)))
    (b64-encode in out)
    ;; use the documented accessor for the collected octets instead of
    ;; reading the stream object's slot directly
    (map 'string #'code-char (flex:get-output-stream-sequence out))))

(deftest base64 ()
  ;; empty input produces empty output
  (should be rtl:blankp (b64str ""))
  ;; 3 input characters: exactly 4 digits, no padding
  (should be string= "TWFu" (b64str "Man"))
  ;; 4 input characters: one character left over, padded with two #\=
  (should be string= "TWFuIA==" (b64str "Man "))
  ;; 5 input characters: two characters left over, padded with one #\=
  (should be string= "TWFuIGk=" (b64str "Man i")))

(defun huffman-encode (envocab str)
  "Encode STR as one bit-vector by concatenating, for each of its
characters, the code that ENVOCAB associates with the character
(looked up with RTL:?).  The result is an adjustable bit-vector with
a fill-pointer."
  (let ((bits (make-array 0 :element-type 'bit :adjustable t :fill-pointer t)))
    (map nil
         (lambda (ch)
           ;; append the code of CH bit by bit
           (map nil
                (lambda (b) (vector-push-extend b bits))
                (rtl:? envocab ch)))
         str)
    bits))

(defun huffman-decode (devocab vec)
  "Decode the bit-vector VEC into a string using DEVOCAB, which maps
codes (bit-vectors) to characters.
Starting at the current position, ever longer slices of VEC are looked up
in DEVOCAB; the first one found yields a character and decoding resumes
right after it.  If no slice starting at the current position is a known
code, that single bit is skipped."
  (let ((len (length vec))
        (start 0)
        (chars '()))
    ;; An explicit cursor is used instead of modifying a DOTIMES variable
    ;; inside its body: whether DOTIMES rebinds or assigns its variable on
    ;; each iteration is implementation-dependent.
    (loop :while (< start len) :do
      (let ((next (loop :for end :from (1+ start) :to len :do
                    (let ((hit (rtl:? devocab (rtl:slice vec start end))))
                      (when hit
                        (push hit chars)
                        (return end))))))
        ;; NEXT is NIL when nothing matched: move on by one bit
        (setf start (or next (1+ start)))))
    (coerce (nreverse chars) 'string)))

(defun huffman-vocabs (str)
  "Build Huffman coding tables for the characters of STR.
Returns a list (ENVOCAB DEVOCAB): ENVOCAB is a hash-table from a character
to its code (a bit-vector) and DEVOCAB is an EQUAL hash-table from a code
back to its character.
An empty STR yields two empty tables.  When STR contains a single distinct
character, that character receives the one-bit code #*0, so that encoded
messages are not empty."
  (let ((counts (make-hash-table))
        ;; MAKE-HEAP is defined elsewhere; presumably this is a min-heap
        ;; ordered by the second element of each pair -- confirm
        (q (make-heap :op '< :key 'rt))
        (envocab (make-hash-table))
        (devocab (make-hash-table :test 'equal)))  ; bit-vectors as keys require
                                                   ; equal comparison
    ;; count character frequencies
    (rtl:dovec (char str)
      (incf (gethash char counts 0)))  ; the third argument of GETHASH supplies
                                       ; the default count of 0
    ;; with no characters there is nothing to pop from the heap
    ;; and no tree to traverse
    (when (plusp (hash-table-count counts))
      ;; put every character paired with its frequency into the heap
      (rtl:dotable (char count counts)
        (heap-push (rtl:pair char count) q))
      ;; build the tree: repeatedly merge the two items popped from the heap
      ;; into a two-element list node carrying the sum of their weights
      (loop :repeat (1- (heap-size q)) :do
        (rtl:with (((lt cl) (heap-pop q))
                   ((rt cr) (heap-pop q)))
          (heap-push (rtl:pair (list lt rt) (+ cl cr))
                     q)))
      ;; traverse the tree in DFS manner
      ;; encoding the path to each leaf node as a bit-vector
      (labels ((dfs (node &optional path)
                 (if (listp node)
                     (progn
                       (dfs (rtl:lt node) (cons 0 path))
                       (dfs (rtl:rt node) (cons 1 path)))
                     ;; a leaf that is also the root (single distinct
                     ;; character) has an empty path: give it the code #*0
                     (let* ((bits (or (reverse path) '(0)))
                            (vec (make-array (length bits)
                                             :element-type 'bit
                                             :initial-contents bits)))
                       (setf (rtl:? envocab node) vec
                             (rtl:? devocab vec) node)))))
        (dfs (lt (heap-pop q)))))
    (list envocab devocab)))

(defun huffman-tables (hts envocab)
  "Return a list with one new EQUAL hash-table per table in HTS.
Every key of the source table (a string) is replaced by its encoding
produced by HUFFMAN-ENCODE with ENVOCAB; the values are kept as they are."
  ;; SB-C::INSTRUMENT-CONSING is an SBCL-internal optimization quality
  ;; (presumably used for allocation profiling).  It is guarded by a feature
  ;; expression so that other implementations can still read this form.
  #+sbcl (declare (optimize sb-c::instrument-consing))
  (mapcar (lambda (ht)
            (let ((rez (make-hash-table :test 'equal)))
              (rtl:dotable (str logprob ht)
                (setf (rtl:? rez (huffman-encode envocab str)) logprob))
              rez))
          hts))

(defun huffman-encode2 (envocab str)
  "Encode STR with the hash-table ENVOCAB (character -> bit-vector code).
Unlike the incremental variant, all codes are looked up first, so that the
result can be allocated once with its final size and then filled in.
Returns a simple bit-vector."
  (let* ((codes (map 'vector
                     (lambda (ch) (gethash ch envocab))
                     str))
         (bits (make-array (reduce #'+ codes :key #'length)
                           :element-type 'bit))
         (pos 0))
    ;; copy every code to its place in the result
    (map nil
         (lambda (code)
           (replace bits code :start1 pos)
           (incf pos (length code)))
         codes)
    bits))

(defun huffman-encode3 (envocab str)
  "Encode STR as a bit-vector using ENVOCAB, a simple-vector that serves as
a jump-table: the code of a character is stored at the index equal to the
character's CHAR-CODE.  The result is an adjustable bit-vector with a
fill-pointer."
  (let ((bits (make-array 0 :element-type 'bit :adjustable t :fill-pointer t)))
    (map nil
         (lambda (ch)
           ;; direct indexing replaces the hash-table lookup
           (map nil
                (lambda (b) (vector-push-extend b bits))
                (svref envocab (char-code ch))))
         str)
    bits))

(defun find-shortest-bitvec (lo hi)
  "Return a bit-vector with the binary digits (after the point) of a short
binary fraction chosen between LO and HI.
Both bounds are doubled repeatedly; as long as their integer parts agree
and neither fractional part is zero, the common digit is recorded.  The
first time this no longer holds, the digit of HI is recorded and the
search stops."
  (let ((bits (make-array 0 :element-type 'bit :adjustable t :fill-pointer t)))
    (loop
      (multiple-value-bind (lo-digit lo-rest) (floor (* lo 2))
        (multiple-value-bind (hi-digit hi-rest) (floor (* hi 2))
          (cond ((and (= lo-digit hi-digit)
                      (not (zerop lo-rest))
                      (not (zerop hi-rest)))
                 ;; the bounds still share this digit: keep it and go on
                 ;; with the fractional parts
                 (vector-push-extend lo-digit bits)
                 (setf lo lo-rest
                       hi hi-rest))
                (t
                 (vector-push-extend hi-digit bits)
                 (return bits))))))))

(deftest find-shortest-bitvec ()
  ;; #*01 stands for the binary fraction 0.01 = 0.25,
  ;; which lies between the two bounds
  (should be equalp #*01 (find-shortest-bitvec 0.214285714 0.357142857)))

(defun arithm-encode (envocab message)
  "Arithmetic-encode MESSAGE using floating-point interval arithmetic.
ENVOCAB is a table from characters to their probabilities.  Starting with
the interval from 0.0 to 1.0, each character of MESSAGE narrows the
interval to the part assigned to it; the result is the bit-vector that
FIND-SHORTEST-BITVEC returns for the final bounds.
NOTE(review): the part assigned to a character depends on the order in
which RTL:DOTABLE visits ENVOCAB, so that order must be stable."
  (let ((low 0.0)
        (high 1.0))
    (map nil
         (lambda (target)
           (let ((width (- high low)))
             ;; walk over the sub-intervals until the one of TARGET is found
             (rtl:dotable (sym prob envocab)
               (let ((step (* prob width)))
                 (cond ((eql target sym)
                        (setf high (+ low step))
                        (return))
                       (t
                        (incf low step)))))))
         message)
    (find-shortest-bitvec low high)))

(defun arithm-encode-correct (envocab message)
  "Arithmetic-encode MESSAGE with 32-bit integer bounds and return the
result as an adjustable bit-vector.
ENVOCAB is looked up with RTL:? for every character of MESSAGE and has to
yield a two-element list (PLO PHI): the lower and upper bound of the
character's share of the current range."
  (let ((lo 0)
        (hi (1- (expt 2 32)))
        ;; number of bits whose value will only be known
        ;; when the next bit is emitted
        (pending-bits 0)
        (rez (make-array 0 :element-type 'bit :adjustable t :fill-pointer t)))
    (flet ((emit-bit (bit)
             ;; output BIT followed by PENDING-BITS copies of its inverse
             (vector-push-extend bit rez)
             (let ((pbit (if (zerop bit) 1 0)))
               (loop :repeat pending-bits :do (vector-push-extend pbit rez))
               (setf pending-bits 0))))
      (rtl:dovec (char message)
                 (rtl:with ((range (- hi lo -1))
                            ((plo phi) (rtl:? envocab char)))
                           ;; narrow the bounds to the share of CHAR;
                           ;; PSETF makes the new HI use the old LO
                           (psetf lo (round (+ lo (* plo range)))
                                  hi (round (+ lo (* phi range) -1)))
                           ;; renormalize until none of the three cases applies
                           (loop
                             (cond ((< hi #.(expt 2 31))
                                    ;; both bounds are in the lower half
                                    (emit-bit 0))
                                   ((>= lo #.(expt 2 31))
                                    ;; both bounds are in the upper half
                                    (emit-bit 1)
                                    (decf lo #.(expt 2 31))
                                    (decf hi #.(expt 2 31)))
                                   ((and (>= lo #.(expt 2 30))
                                         (< hi (+ #.(expt 2 30) #.(expt 2 31))))
                                    ;; the bounds straddle the middle but lie
                                    ;; within the two middle quarters:
                                    ;; postpone the decision about the bit
                                    (decf lo #.(expt 2 30))
                                    (decf hi #.(expt 2 30))
                                    (incf pending-bits))
                                   (t (return)))
                             ;; shift both bounds left by one bit, keeping
                             ;; them in 32 bits; HI gets a 1 as its lowest bit
                             (psetf lo (mask32 (ash lo 1))
                                    hi (mask32 (1+ (ash hi 1)))))))
      ;; finish with one more bit, chosen by the quarter LO is in,
      ;; plus whatever is still pending
      (incf pending-bits)
      (emit-bit (if (< lo #.(expt 2 30)) 0 1)))
    rez))

(defun mask32 (num)
  "Confine the integer NUM to 32 bits: return its 32 low-order bits
as a non-negative integer."
  (ldb (byte 32 0) num))

(defun bitvec->int (bits)
  "Return the non-negative integer whose binary digits, most significant
first, are the elements of the sequence BITS.  An empty BITS yields 0."
  ;; The initial value makes the empty case well defined: without it REDUCE
  ;; would call the two-argument function with no arguments.
  (reduce (lambda (acc bit) (+ (ash acc 1) bit))
          bits
          :initial-value 0))

(defun arithm-decode (dedict vec size)
  "Decode SIZE characters from the bit-vector VEC and return them as a
string.  VEC is expected to come from the 32-bit integer arithmetic
encoder.
DEDICT is a table from characters to two-element lists (PLO PHI), the
lower and upper bound of the character's share of the range.
NOTE(review): the character chosen is the first one visited by
RTL:DOTABLE whose PHI is not below the current position, so DEDICT has to
be traversed in ascending order of the ranges -- confirm for the table
type in use."
  (rtl:with ((len (length vec))
             (lo 0)
             (hi (1- (expt 2 32)))
             (head (min 32 len))
             ;; VAL is a 32-bit window over VEC.  When VEC is shorter than
             ;; the window, its bits are the high-order ones and have to be
             ;; shifted into place; the missing low bits are zeros, just
             ;; like the bits read past the end of VEC below.
             (val (if (zerop head)
                      0
                      (ash (bitvec->int (subseq vec 0 head))
                           (- 32 head))))
             (off 32)
             (rez (make-string size)))
    (dotimes (i size)
      (rtl:with ((range (- hi lo -1))
                 ;; relative position of VAL between the current bounds
                 (prob (/ (- val lo) range)))
        (rtl:dotable (char r dedict)
          (rtl:with (((plo phi) r))
            (when (>= phi prob)
              ;; record the character and narrow the bounds to its share;
              ;; PSETF makes the new HI use the old LO
              (psetf (char rez i) char
                     lo (round (+ lo (* plo range)))
                     hi (round (+ lo (* phi range) -1)))
              (return))))
        ;; renormalize the bounds and the window with the same three cases
        ;; that the encoder uses
        (loop
          (cond ((< hi #.(expt 2 31))
                 ;; do nothing
                 )
                ((>= lo #.(expt 2 31))
                 (decf lo #.(expt 2 31))
                 (decf hi #.(expt 2 31))
                 (decf val #.(expt 2 31)))
                ((and (>= lo #.(expt 2 30))
                      (< hi #.(* 3 (expt 2 30))))
                 (decf lo #.(expt 2 30))
                 (decf hi #.(expt 2 30))
                 (decf val #.(expt 2 30)))
                (t
                 (return)))
          ;; shift everything left by one bit, pulling the next bit of VEC
          ;; (or 0 once VEC is exhausted) into the window
          (psetf lo (mask32 (ash lo 1))
                 hi (mask32 (1+ (ash hi 1)))
                 val (mask32 (+ (ash val 1)
                                 (if (< off len)
                                     (aref vec off)
                                     0)))
                 off (1+ off)))))
    rez))

(deftest compression ()
  ;; Two tables over the same 7 characters are built from the plists below:
  ;; DICT1 maps a character to its probability,
  ;; DICT2 maps a character to the (lower upper) bounds of its share of
  ;; the unit range -- the running sums of the probabilities from DICT1.
  (rtl:with (((dict1 dict2)
              (mapcar (lambda (d)
                        (let ((dict (make-hash-table)))
                          (loop :for (k v) :on d :by #'cddr
                                :do (rtl:sethash k dict v))
                          dict))
                      '((#\e 1/14
                         #\a 1/14
                         #\h 1/14
                         #\i 2/14
                         #\s 3/14
                         #\t 3/14
                         #\Space 3/14)
                        (#\e (0 1/14)
                         #\a (1/14 1/7)
                         #\h (1/7 3/14)
                         #\i (3/14 5/14)
                         #\s (5/14 4/7)
                         #\t (4/7 11/14)
                         #\Space (11/14 1))))))
    ;; the floating-point encoder
    (should be equal #*100110110100001110000001
            (arithm-encode dict1 "this is a test"))
    ;; the 32-bit integer encoder
    (should be equal #*10011011010000111000001101010110010101
          (arithm-encode-correct dict2 "this is a test"))
    ;; round trip through the integer encoder and the decoder;
    ;; 14 is the length of the message
    (should be string= "this is a test"
            (arithm-decode dict2 (arithm-encode-correct dict2 "this is a test")
                           14))))
